import copy
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
from GradientVariance import GradientVariance
from BootstrapLoss import BootstrapLoss
from Trace import LimitNeighborLossScaledTrace
import time

def _loss_hessian(model, criterion, inputs, targets, dim):
    """Return the dense Hessian of ``criterion(model(inputs), targets)``.

    The Hessian is taken with respect to all parameters of ``model``,
    flattened and concatenated in ``model.parameters()`` order.

    Args:
        model: Module whose parameters are differentiated.
        criterion: Callable mapping ``(prediction, target)`` to a scalar loss.
        inputs: Tensor fed to ``model``.
        targets: Tensor compared against the model output.
        dim: Total number of scalar parameters in ``model``.

    Returns:
        A ``(dim, dim)`` tensor whose i-th row is the gradient of the i-th
        entry of the flattened loss gradient.
    """
    params = list(model.parameters())
    loss = criterion(model(inputs), targets)
    # create_graph=True keeps the first-order gradient differentiable so each
    # of its entries can be differentiated a second time below.
    first_order = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.flatten() for g in first_order])

    hessian = torch.zeros((dim, dim))
    for i in range(dim):
        # retain_graph=True: the same graph is reused for every row.
        row = torch.autograd.grad(flat_grad[i], params, retain_graph=True)
        hessian[i] = torch.cat([r.flatten() for r in row])
    return hessian


def GD(ini_model, x_train, y_train, x_test, y_test, d, N_train, learningrate, eps, bs, Replacement, seed, K, KBS, EpTimes, ComputeGV, ComputeBL, ComputeSingleHessian, z_i, z_i_label):
    """Train a copy of ``ini_model`` with full-batch gradient descent on MSE loss.

    Every iteration performs one optimizer step on the whole training set and
    then records the flattened parameters and the train/test losses. Optional
    diagnostics are refreshed every ``K`` (resp. ``KBS``) iterations and the
    value is written to the next ``K`` (resp. ``KBS``) slots of the
    corresponding array.

    Args:
        ini_model: Initial model; it is deep-copied and never modified.
        x_train, y_train: NumPy arrays with training inputs / targets.
        x_test, y_test: NumPy arrays with test inputs / targets.
        d: Forwarded unchanged to ``LimitNeighborLossScaledTrace``.
        N_train: Forwarded unchanged to ``GradientVariance``,
            ``LimitNeighborLossScaledTrace`` and ``BootstrapLoss``.
        learningrate: Step size of ``torch.optim.SGD``.
        eps, EpTimes: The number of iterations is ``int(eps * EpTimes)``.
        bs, Replacement: Unused here; kept so the signature stays unchanged.
        seed: Passed to ``torch.manual_seed`` before training.
        K: Refresh period of the gradient-variance / trace diagnostics.
        KBS: Refresh period of the bootstrap-loss diagnostic.
        ComputeGV: Enables the ``K``-periodic diagnostics and the final
            Hessian on the full training set.
        ComputeBL: Enables the ``KBS``-periodic bootstrap loss.
        ComputeSingleHessian: When ``ComputeGV`` is false, compute the final
            Hessian on ``(z_i, z_i_label)`` instead.
        z_i, z_i_label: NumPy arrays; only read in the single-Hessian case.

    Returns:
        Tuple ``(model, ModelTraj, TrainLosses, TestLosses, GradientVariances,
        GradientNorms, BootstrapLosses, Products, Frobeniuses, HessianTraces,
        CovarianceTraces, AccProducts, H)``. ``ModelTraj`` is a detached
        ``(ites, dim)`` tensor, the per-iteration series are NumPy arrays of
        length ``ites`` (zeros where a diagnostic is disabled), and ``H`` is a
        ``(dim, dim)`` tensor (zeros when no Hessian was requested).
    """
    starting_time = time.time()
    torch.manual_seed(seed)
    model = copy.deepcopy(ini_model)

    # Use the same integer count for buffer sizes and for the loop itself.
    ites = int(eps * EpTimes)
    dim = sum(p.numel() for p in model.parameters())

    ModelTraj = torch.zeros((ites, dim))
    TrainLosses = np.zeros(ites)
    TestLosses = np.zeros(ites)
    GradientVariances = np.zeros(ites)
    GradientNorms = np.zeros(ites)
    BootstrapLosses = np.zeros(ites)
    Products = np.zeros(ites)
    Frobeniuses = np.zeros(ites)
    HessianTraces = np.zeros(ites)
    CovarianceTraces = np.zeros(ites)
    AccProducts = np.zeros(ites)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learningrate)

    # Everything stays on the CPU.
    inputs = torch.from_numpy(x_train)
    labels = torch.from_numpy(y_train)
    test_inputs = torch.from_numpy(x_test)
    test_labels = torch.from_numpy(y_test)

    # Running sum of the covariance matrices returned by the trace helper.
    AccCovariance = torch.zeros(dim, dim)

    for step in range(ites):
        # One full-batch gradient step.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Leave clean gradient buffers for the diagnostics helpers below.
        optimizer.zero_grad()

        # Bookkeeping must not be tracked by autograd: otherwise ModelTraj
        # would require grad and accumulate graph history on every iteration.
        with torch.no_grad():
            ModelTraj[step] = torch.cat(
                [p.reshape(-1) for p in model.parameters() if p.requires_grad]
            )
            TrainLosses[step] = criterion(model(inputs), labels).item()
            TestLosses[step] = criterion(model(test_inputs), test_labels).item()

        # Test the flag first so a disabled diagnostic never evaluates the
        # modulo (K == 0 or KBS == 0 would otherwise raise).
        if ComputeGV and step % K == 0:
            window = slice(step, step + K)
            GV, GN = GradientVariance(model, inputs, labels, N_train)
            GradientVariances[window] = GV
            GradientNorms[window] = GN
            # NOTE(review): 10e-12 equals 1e-11 -- confirm this radius is intended.
            Product, Frobenius, HessianTrace, CovarianceTrace, Hessian, NewCovariance = LimitNeighborLossScaledTrace(
                test_model=model, inputs=x_train, labels=y_train, d=d,
                N_train=N_train, radiuses=[10e-12])
            Products[window] = Product
            Frobeniuses[window] = Frobenius
            HessianTraces[window] = HessianTrace
            CovarianceTraces[window] = CovarianceTrace
            AccCovariance += NewCovariance
            # .item() yields a plain float, which NumPy can always store.
            AccProducts[window] = torch.trace(torch.mm(Hessian, AccCovariance)).item()

        if ComputeBL and step % KBS == 0:
            BL = BootstrapLoss(model, inputs, labels, N_train, 1)
            BootstrapLosses[step:step + KBS] = BL

    print("GD AccCovariance Trace is {}".format(torch.trace(AccCovariance)))

    # Final Hessian of the loss at the trained parameters. The full-data
    # variant takes precedence over the single-sample one.
    if ComputeGV:
        H = _loss_hessian(model, criterion, inputs, labels, dim)
    elif ComputeSingleHessian:
        H = _loss_hessian(model, criterion, torch.from_numpy(z_i),
                          torch.from_numpy(z_i_label), dim)
    else:
        H = torch.zeros((dim, dim))

    print('GD runtime is {}'.format(time.time()-starting_time))
    return model, ModelTraj, TrainLosses, TestLosses, GradientVariances, GradientNorms, BootstrapLosses, Products, Frobeniuses, HessianTraces, CovarianceTraces, AccProducts, H